import numpy as np
import cv2

# Blend-direction flags used by mergeImge: LEFTDIR when img1 (the base
# canvas) has content just left of the overlap, RIGHTDIR otherwise.
LEFTDIR, RIGHTDIR = 1, 2
#  get sift ,flann Machine
def getMachine():
    FLANN_INDEX_KDTREE = 1
    index_params = dict(algorithm=FLANN_INDEX_KDTREE, trees=5)
    search_params = dict(checks=50)
    flann = cv2.FlannBasedMatcher(index_params, search_params)
    sift = cv2.xfeatures2d_SIFT().create()
    return sift,flann

def imgProcess(img,top,bot,left,right):
    """Surround img with a constant black border and derive its grayscale copy.

    Args:
        img: BGR image to pad.
        top, bot, left, right: border thickness in pixels for each side.

    Returns:
        (bordered, gray): the padded BGR image and its single-channel
        grayscale version.
    """
    bordered = cv2.copyMakeBorder(
        img, top, bot, left, right, cv2.BORDER_CONSTANT, value=(0, 0, 0)
    )
    return bordered, cv2.cvtColor(bordered, cv2.COLOR_BGR2GRAY)

def findEdgeDot(img,x1,x2,y1,y2):
    """Count the zero-valued pixels of a 2-D image inside an inclusive box.

    Args:
        img: 2-D array (e.g. a binary mask).
        x1, x2: first and last column, both inclusive.
        y1, y2: first and last row, both inclusive.

    Returns:
        Number of pixels equal to zero in img[y1..y2, x1..x2] as an int.
        An empty range (x2 < x1 or y2 < y1) yields 0.

    Raises:
        IndexError: if a row/column index is outside the image.
    """
    if x2 < x1 or y2 < y1:
        return 0
    rows = np.arange(y1, y2 + 1)
    cols = np.arange(x1, x2 + 1)
    # np.ix_ fancy indexing (instead of slicing) keeps the semantics of the
    # former per-pixel img.item(j, i) loop: negative indices wrap around and
    # out-of-range indices raise IndexError.  Vectorizing matters because
    # getMaxInnerRect calls this four times per shrink step.
    patch = img[np.ix_(rows, cols)]
    return int(patch.size - np.count_nonzero(patch))

def getSmallOuterRect(img):
    """Return the bounding box of the largest non-black region of a BGR image.

    Pixels whose grayscale value is <= 1 are treated as background (the
    image is thresholded at 1).

    Args:
        img: BGR image.

    Returns:
        (x, y, w, h) bounding rectangle of the external contour with the
        largest area.

    Raises:
        ValueError: if the image contains no foreground pixels.
    """
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    _, binary = cv2.threshold(gray, 1, 255, cv2.THRESH_BINARY)
    # findContours returns (image, contours, hierarchy) in OpenCV 3.x but
    # (contours, hierarchy) in 2.x and 4.x; contours is second-to-last in
    # every version, so index instead of unpacking a fixed arity.
    contours = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
    if len(contours) == 0:
        raise ValueError("getSmallOuterRect: image has no non-black region")
    # max() keeps the first of equal-area contours, like np.argmax did.
    largest = max(contours, key=cv2.contourArea)
    return cv2.boundingRect(largest)

def _countEdgeDots(binary, x, y, w, h):
    # Count background (zero) pixels on the top, bottom, left and right edges
    # of rectangle (x, y, w, h) in binary; returns (top, bot, left, right).
    top = findEdgeDot(binary, x, x + w - 1, y, y)
    bot = findEdgeDot(binary, x, x + w - 1, y + h - 1, y + h - 1)
    lef = findEdgeDot(binary, x, x, y, y + h - 1)
    rig = findEdgeDot(binary, x + w - 1, x + w - 1, y, y + h - 1)
    return top, bot, lef, rig


def getMaxInnerRect(img,step):
    """Shrink the full-image rectangle until none of its edges touch black.

    The input is a BGR image (converted to gray and thresholded at 1, so
    gray values <= 1 count as black).  Starting from the whole image, the
    edge with the most black pixels is moved inward by `step` pixels, and
    this repeats until all four edges are black-free.  This is a greedy
    approximation, not a guaranteed maximum-area inner rectangle.

    Args:
        img: BGR image.
        step: positive number of pixels to move an edge per iteration.

    Returns:
        (x, y, w, h) of the resulting rectangle.  If the rectangle collapses
        before a black-free one is found, w and/or h is 0.

    Raises:
        ValueError: if step < 1 (the loop would never terminate).
    """
    if step < 1:
        raise ValueError("step must be a positive integer")
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    _, binary = cv2.threshold(gray, 1, 255, cv2.THRESH_BINARY)
    x, y = 0, 0
    h, w = binary.shape
    # Stop once the rectangle is empty: with w or h <= 0 the edge coordinates
    # fall outside the rectangle (negative indices wrapped around, or y ran
    # past the last row and raised IndexError).
    while w > 0 and h > 0:
        top, bot, lef, rig = _countEdgeDots(binary, x, y, w, h)
        if not (top or bot or lef or rig):
            break
        worst = max(top, bot, lef, rig)
        # Ties resolve in top, bottom, left, right order.
        if worst == top:
            y += step
            h -= step
        elif worst == bot:
            h -= step
        elif worst == lef:
            x += step
            w -= step
        else:
            w -= step
    return x, y, max(w, 0), max(h, 0)
            
def _blendOnCanvas(srcImg, warpImg):
    """Composite warpImg onto srcImg, feather-blending where both have content.

    Both arguments are same-sized BGR uint8 arrays.  A pixel counts as
    content when any of its channels is non-zero.  Pixels covered by only one
    image are copied from it.  Pixels covered by both are mixed with a weight
    that ramps linearly across columns left..right, the first and last
    columns in which both images have content somewhere.

    Returns:
        A new uint8 array with the same shape as srcImg.
    """
    srcMask = srcImg.any(axis=2)
    warpMask = warpImg.any(axis=2)
    # srcImg wherever it has content, warpImg everywhere else.
    res = np.where(srcMask[..., None], srcImg, warpImg).astype(np.uint8, copy=False)
    both = srcMask & warpMask
    if not both.any():
        # No overlapping pixel, so there is nothing to blend.
        return res

    sharedCols = np.flatnonzero(srcMask.any(axis=0) & warpMask.any(axis=0))
    left, right = int(sharedCols[0]), int(sharedCols[-1])
    # srcImg having content just left of the overlap means it is the
    # left-hand image.  Guard left == 0 so that index -1 does not wrap.
    if left > 0 and srcMask[:, left - 1].any():
        direction = LEFTDIR
    else:
        direction = RIGHTDIR

    colIdx = np.arange(srcImg.shape[1], dtype=np.float64)
    srcDist = np.abs(colIdx - left)
    warpDist = np.abs(colIdx - right)
    total = srcDist + warpDist
    # Before the flip, alpha (warpImg's weight) is 1 at `left` and 0 at
    # `right`.  It is flipped when srcImg is the left-hand image, so each
    # image weighs more the closer a pixel is to its own side.  A one-column
    # overlap has total == 0: weight both images equally instead of dividing
    # by zero.
    alpha = np.full_like(colIdx, 0.5)
    nonzero = total > 0
    alpha[nonzero] = 1 - srcDist[nonzero] / total[nonzero]
    if direction == LEFTDIR:
        alpha = 1 - alpha

    _, overlapCols = np.nonzero(both)
    a = alpha[overlapCols][:, None]  # one weight per overlapping pixel
    # Float blend, clipped, then truncated back to uint8 on assignment.
    res[both] = np.clip(srcImg[both] * (1 - a) + warpImg[both] * a, 0, 255)
    return res


def mergeImge(img1,img2,sift,flann):
    """Stitch img2 onto img1 using SIFT matches and a RANSAC homography.

    Both images are first padded with a black border of half their height
    (top/bottom) and half their width (left/right).  img2 is warped into
    img1's padded frame, the two are feather-blended where they overlap,
    and the result is cropped twice.

    Args:
        img1: base BGR image; its padded frame is the output canvas.
        img2: BGR image to warp onto img1.
        sift: feature extractor exposing detectAndCompute (see getMachine).
        flann: descriptor matcher exposing knnMatch (see getMachine).

    Returns:
        (True, resImg, outimg) on success.  resImg is the blended canvas
        cropped to the bounding box of its largest non-black region, and
        outimg is resImg shrunk further so no edge touches black pixels.
        (False, None, None) when no descriptors are found, there are not
        more than 10 ratio-test matches, or no homography can be estimated.
    """
    srcImg, img1gray = imgProcess(img1, img1.shape[0]//2, img1.shape[0]//2, img1.shape[1]//2, img1.shape[1]//2)
    testImg, img2gray = imgProcess(img2, img2.shape[0]//2, img2.shape[0]//2, img2.shape[1]//2, img2.shape[1]//2)

    kp1, des1 = sift.detectAndCompute(img1gray, None)
    kp2, des2 = sift.detectAndCompute(img2gray, None)
    # Without keypoints the descriptors are None and knnMatch would raise.
    if des1 is None or des2 is None:
        return (False, None, None)

    # Lowe's ratio test.  knnMatch can return fewer than k=2 neighbours for
    # a query, so skip those pairs instead of failing to unpack them.
    good = [
        pair[0]
        for pair in flann.knnMatch(des1, des2, k=2)
        if len(pair) == 2 and pair[0].distance < 0.7 * pair[1].distance
    ]
    MIN_MATCH_COUNT = 10
    if len(good) <= MIN_MATCH_COUNT:
        return (False, None, None)

    src_pts = np.float32([kp1[m.queryIdx].pt for m in good]).reshape(-1, 1, 2)
    dst_pts = np.float32([kp2[m.trainIdx].pt for m in good]).reshape(-1, 1, 2)
    # M maps padded-img1 coordinates to padded-img2 coordinates.
    M, _ = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, 5.0)
    if M is None:
        return (False, None, None)

    rows, cols = srcImg.shape[:2]
    # With WARP_INVERSE_MAP, each canvas pixel p is sampled from testImg at
    # the point M maps p to, which renders img2 in img1's padded frame.
    # dsize is (width, height) and must match the canvas.  The old call
    # passed (2*height, 2*width) of testImg, which transposed the output
    # size and could leave it too small to cover img1's canvas.
    warpImg = cv2.warpPerspective(testImg, M, (cols, rows), flags=cv2.WARP_INVERSE_MAP)

    res = _blendOnCanvas(srcImg, warpImg)

    # Crop to the largest non-black region, then shrink until no edge
    # touches a black pixel.
    x, y, w, h = getSmallOuterRect(res)
    resImg = res[y:y+h, x:x+w]
    x, y, w, h = getMaxInnerRect(resImg, 2)
    outimg = resImg[y:y+h, x:x+w]
    return (True, resImg, outimg)


if __name__ == "__main__":
    # Demo: stitch two sample images and show the cropped result.
    img1 = cv2.imread("./img/test1.jpg")
    img2 = cv2.imread("./img/test2.jpg")
    # cv2.imread returns None instead of raising when a file can't be read.
    if img1 is None or img2 is None:
        raise SystemExit("could not read ./img/test1.jpg or ./img/test2.jpg")
    sift, flann = getMachine()
    res = mergeImge(img1, img2, sift, flann)
    # Guard for both a bare False and a (False, ...) tuple as the failure value.
    if res and res[0]:
        cv2.imshow("res", res[2])
        cv2.waitKey()
        cv2.destroyAllWindows()
    else:
        print("stitching failed: not enough matches or no homography")

